# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right


class Solution:
    def longestUnivaluePath(self, root: "Optional[TreeNode]") -> int:
        """Return the length, in edges, of the longest path whose nodes share one value.

        The path may bend at a node (run up through its left subtree, through
        the node itself, and down its right subtree), and it does not have to
        pass through the root.

        A single iterative post-order traversal computes, for every node, the
        longest downward same-value chain starting there ("arm"). The best path
        bending at that node is then left arm + right arm. This is O(n) time.
        Because it uses an explicit stack instead of recursion, it does not hit
        Python's recursion limit on degenerate (list-shaped) trees.

        :param root: root of the binary tree, or ``None`` for an empty tree.
                     Nodes are expected to expose ``val``, ``left`` and ``right``.
        :return: number of edges on the longest univalue path; 0 for an empty
                 tree or a single node.
        """
        if root is None:
            return 0

        # arm[id(node)] = edges in the longest downward chain of equal values
        # starting at node. Keyed by id() so nodes need not be hashable.
        arm = {}
        best = 0

        # (node, children_done): a node is first pushed with False so its
        # children get processed before it (post-order), then revisited with True.
        stack = [(root, False)]
        while stack:
            node, children_done = stack.pop()
            if not children_done:
                stack.append((node, True))
                for child in (node.left, node.right):
                    if child is not None:
                        stack.append((child, False))
                continue

            left = right = 0
            # A child extends the chain only if it carries the same value.
            if node.left is not None and node.left.val == node.val:
                left = arm[id(node.left)] + 1
            if node.right is not None and node.right.val == node.val:
                right = arm[id(node.right)] + 1

            # The path may bend here, joining both arms.
            best = max(best, left + right)
            # A parent can only continue one arm (a path cannot fork).
            arm[id(node)] = max(left, right)

        return best

            
            

            
    